# -*- coding: utf-8 -*-
"""
Created on Fri Oct  1 18:38:08 2021
因 tf2 运行速度慢
将 神经网络 从 tf2 中解放出来
直接用 矩阵 来作为 神经网络
@author: Waikikilick
"""

import tensorflow as tf
from tensorflow.keras import layers, Sequential
import numpy as np
import random
import warnings
from tensorflow import keras
from scipy.linalg import expm
import warnings
warnings.filterwarnings("ignore")
from time import *

# Seed every random-number generator the script touches (TensorFlow, NumPy,
# Python's `random`) with one shared value so runs are reproducible.  The
# Python `random` state in particular fixes the shuffle in env.psi_set and
# therefore which states land in the training/validation/testing splits.
SEED = 1
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

class env(object):
    """Single-qubit environment whose goal is to steer a state to |0>.

    A state psi = [a, b]^T (2x1 complex column) is handed to the agent as
    the real vector ``[Re a, Re b, Im a, Im b]`` because the network cannot
    take complex inputs.  One step evolves psi for ``dt`` under
    ``H = action * sigma_z + sigma_x``.

    Args:
        action_space: allowed actions (default ``[0, 1]``; the caller
            normally supplies the real list).
        dt: evolution time of one step.
    """

    def __init__(self, action_space=None, dt=0.1):
        # None sentinel: a literal list default would be shared by all
        # instances (and mutations of it would leak between them).
        if action_space is None:
            action_space = [0, 1]
        self.action_space = action_space
        self.n_actions = len(self.action_space)
        self.n_features = 4  # length of the real state vector
        # np.asmatrix instead of np.mat: np.mat was removed in NumPy 2.0.
        self.target_psi = np.asmatrix([[1], [0]], dtype=complex)  # target |0>
        self.s_x = np.asmatrix([[0, 1], [1, 0]], dtype=complex)
        self.s_z = np.asmatrix([[1, 0], [0, -1]], dtype=complex)
        self.dt = dt
        self.training_set, self.validation_set, self.testing_set = self.psi_set()

    def psi_set(self):
        """Build a shuffled grid of Bloch-sphere states and split it.

        6 polar angles strictly between 0 and pi, times 21 azimuthal angles,
        plus the two poles gives 6 * 21 + 2 = 128 states.  The list is
        shuffled with Python's ``random`` module (so the split depends on its
        seed) and cut into 32 training, 32 validation and 64 testing states,
        each a 2x1 complex ``np.matrix``.
        """
        theta_num = 6    # polar angles, excluding 0 and pi
        varphi_num = 21  # azimuthal angles around one circle

        # 7 equally spaced points in [0, pi) minus the first one (theta = 0).
        theta = np.linspace(0, np.pi, theta_num + 1, endpoint=False)[1:]
        varphi = np.linspace(0, np.pi * 2, varphi_num, endpoint=False)

        # Same ordering as a theta-outer / varphi-inner double loop, so the
        # shuffle below yields the same split for a given seed.
        psi_set = [
            np.asmatrix([[np.cos(t / 2)],
                         [np.sin(t / 2) * (np.cos(p) + np.sin(p) * 1j)]])
            for t in theta
            for p in varphi
        ]
        psi_set.append(np.asmatrix([[1], [0]], dtype=complex))  # north pole
        psi_set.append(np.asmatrix([[0], [1]], dtype=complex))  # south pole
        random.shuffle(psi_set)

        training_set = psi_set[0:32]
        validation_set = psi_set[32:64]
        testing_set = psi_set[64:128]
        return training_set, validation_set, testing_set

    def reset(self, init_psi):
        """Start an episode at ``init_psi``; return its real 4-vector form."""
        return np.array([init_psi[0, 0].real, init_psi[1, 0].real,
                         init_psi[0, 0].imag, init_psi[1, 0].imag])

    @staticmethod
    def _state_to_psi(state):
        """Convert ``[Re a, Re b, Im a, Im b]`` into a 2x1 complex matrix."""
        state = np.asarray(state)
        half = len(state) // 2
        return np.asmatrix((state[:half] + state[half:] * 1j).reshape(-1, 1))

    @staticmethod
    def _psi_to_state(psi):
        """Convert a complex column state back to ``[Re..., Im...]``."""
        flat = np.asarray(psi).ravel()
        return np.concatenate([flat.real, flat.imag])

    def _fidelity(self, psi):
        """Return |<psi|target>|^2 for a 2x1 complex matrix ``psi``."""
        return (np.abs(psi.H * self.target_psi) ** 2).item(0).real

    def init_fid(self, state):
        """Fidelity of the real-vector ``state`` with the target state."""
        return self._fidelity(self._state_to_psi(state))

    def step(self, state, action, nstep):
        """Evolve ``state`` for one ``dt`` under the field selected by ``action``.

        Args:
            state: current real 4-vector.
            action: field amplitude multiplying sigma_z (used as ``float(action)``).
            nstep: number of steps already taken in this episode.

        Returns:
            ``(next_state, done, fid)``; ``done`` is True once the fidelity
            reaches 0.9999 or ``nstep >= 20``.
        """
        psi = self._state_to_psi(state)

        H = float(action) * self.s_z + 1 * self.s_x
        U = expm(-1j * H * self.dt)
        psi = np.asmatrix(U) * psi  # matrix product: next state

        fid = self._fidelity(psi)
        done = (fid >= 0.9999 or nstep >= 20)

        # Back to the real representation: the network cannot take complex
        # inputs, otherwise ordinary gradient-based backprop would not apply.
        return self._psi_to_state(psi), done, fid
    
def relu(vector):
    """Apply ReLU to ``vector`` in place, one leading-axis entry at a time.

    Every entry ``vector[i]`` (a scalar for 1-D input, a whole row for a
    matrix) is replaced by ``np.maximum(0, vector[i])``.  The same, now
    modified, object is returned.
    """
    for idx, entry in enumerate(vector):
        vector[idx] = np.maximum(0, entry)
    return vector

def agent_matrix(x, weights=None):
    """Greedy action of the trained Q-network, evaluated with plain NumPy.

    Re-implements the dense forward pass as matrix products, so no
    TensorFlow call is needed per step.

    Args:
        x: observation (real state vector) fed to the first layer.
        weights: flat sequence ``[W1, b1, W2, b2, ...]`` of layer kernels
            and biases (TensorFlow variables or array-likes).  Defaults to
            the module-level ``net`` assigned in ``__main__``; presumably
            that is Keras ``trainable_variables`` in kernel/bias order -
            confirm against the saved model.

    Returns:
        Index of the largest output unit.
    """
    if weights is None:
        weights = net  # module-level global assigned in __main__
    # Convert once per call; accepts tf.Variable (.numpy()) or plain arrays.
    params = [w.numpy() if hasattr(w, "numpy") else np.asarray(w)
              for w in weights]
    # np.asmatrix, not np.mat: np.mat was removed in NumPy 2.0.
    out = np.asmatrix(x)
    for kernel, bias in zip(params[0::2], params[1::2]):
        # ReLU after every layer, including the last, as the original
        # three-layer version did.
        # NOTE(review): if the Keras output layer is linear, clipping
        # negative Q-values here can change the argmax - confirm.
        out = np.maximum(out * kernel + bias, 0)
    return np.argmax(out)

def state_evolution(psi, action_list, environment=None, substeps=10):
    """Finely sampled trajectory of ``psi`` under a sequence of actions.

    Each action is applied for ``environment.dt`` under
    ``H = action * sigma_z + sigma_x``. The interval is split into
    ``substeps`` equal pieces so the path can be plotted smoothly.

    Args:
        psi: initial complex state (``np.matrix`` or ndarray column).
        action_list: actions to apply, in order.
        environment: object exposing ``s_x``, ``s_z`` and ``dt``; defaults
            to the module-level ``env`` instance created in ``__main__``.
        substeps: number of sub-intervals per action (default 10).

    Returns:
        List of ``1 + substeps * len(action_list)`` real vectors
        ``[Re a, Re b, Im a, Im b]``, starting with the initial state.
    """
    if environment is None:
        environment = env  # module-level instance assigned in __main__

    psi_list = [psi]
    for action in action_list:
        H = float(action) * environment.s_z + 1 * environment.s_x
        # Identical propagator for every sub-step of one action: compute once.
        U = np.asarray(expm(-1j * H * environment.dt / substeps))
        for _ in range(substeps):
            # `@` is a true matrix product whether psi is np.matrix or a
            # plain ndarray; `*` would be element-wise on an ndarray.
            psi = U @ psi
            psi_list.append(psi)

    state_list = []
    for psi_k in psi_list:
        flat = np.asarray(psi_k).ravel()
        state_list.append(np.concatenate([flat.real, flat.imag]))
    return state_list

def testing_point(init_psi):
    fid_list_tem = []
    action_list = []
    observation = env.reset(init_psi)
    fid_list_tem.append(env.init_fid(observation))
    nstep = 0 
    while True:
        action = agent_matrix(observation) 
        observation, done, fid = env.step(observation, action, nstep)  
        fid_list_tem.append(fid)
        action_list.append(action)
        
        nstep += 1
        if done:
            break
    fid_max = max(fid_list_tem)
    max_index = fid_list_tem.index(max(fid_list_tem))
    action_list = action_list[0:max_index]
    fid_list_tem = fid_list_tem[0:max_index+1]
        
    return fid_max, fid_list_tem, action_list

def positions(state_list):
    """Bloch-sphere coordinates of a batch of single-qubit states.

    Args:
        state_list: array-like of shape (n, 4), or a single length-4 vector.
            Each row is ``[Re a, Re b, Im a, Im b]`` for psi = [a, b]^T.
            For example, ``[[1+2j],[3+4j]]`` is stored as ``[1, 3, 2, 4]``.

    Returns:
        Float array of shape (n, 3) with columns x, y, z. These are the
        expectation values of sigma_x, sigma_y and sigma_z.
    """
    states = np.atleast_2d(np.asarray(state_list, dtype=float))
    alpha = states[:, 0] + 1j * states[:, 2]
    beta = states[:, 1] + 1j * states[:, 3]

    # <sigma_x> + i <sigma_y> = 2 * conj(alpha) * beta.  conj(alpha) matters:
    # after time evolution alpha is generally complex. Treating it as real
    # (|alpha| * beta) rotates each point about z by arg(alpha).
    coherence = 2 * alpha.conj() * beta

    out = np.empty((states.shape[0], 3))
    out[:, 0] = coherence.real
    out[:, 1] = coherence.imag
    # <sigma_z>; equals 2|alpha|^2 - 1 for normalized states.
    out[:, 2] = np.abs(alpha) ** 2 - np.abs(beta) ** 2
    return out

if __name__ == "__main__":

    dt = np.pi/10
    # Load the trained network and expose its weights as the global `net`.
    # agent_matrix reads `net`; state_evolution and testing_point read the
    # global `env` instance; keep both names.
    network = keras.models.load_model('dqn_tf2_single_qubit_to_0_saved_model')
    net = network.trainable_variables
    # NOTE: this rebinds the name `env` from the class to the instance.
    env = env(action_space=list(range(4)),  # 4 allowed actions: 0 .. 3
              dt=dt)
    testing_set = env.testing_set

    test_psi = testing_set[33]

    print('量子态为：', test_psi)
    fid_max, fid_list_tem, action_list = testing_point(test_psi)
    print('最大保真度为: ',fid_max)
    print('保真度记录为：', fid_list_tem)
    print('动作纪录为：', action_list)
    print('所用动作数为：', len(action_list))

    state_list = state_evolution(test_psi, action_list)

    # Distinct name so the `positions` function is not shadowed by its result.
    bloch_positions = positions(np.array(state_list))
    print(bloch_positions)
   
# 量子态为： [[ 0.90096887+0.j        ]
#            [-0.31805929+0.29511589j]]
# 最大保真度为:  0.9998718319486674
# 保真度记录为： [0.8117449009293668, 0.9084932786565539, 0.8492111081128197, 0.6174650139286879, 0.6970519023901329, 0.9103820371112604, 0.9998718319486674]
# 动作纪录为： [0, 0, 3, 3, 0, 2]
# 所用动作数为： 6

# positions: 
    # array([[-0.57312303,  0.53178046,  0.6234898 ],
    #        [-0.57848623,  0.48525912,  0.65565026],
    #        [-0.58276567,  0.4368677 ,  0.68522316],
    #        [-0.58594236,  0.38677746,  0.7120918 ],
    #        [-0.58800173,  0.33516701,  0.73615015],
    #        [-0.58893364,  0.2822215 ,  0.75730324],
    #        [-0.58873243,  0.22813181,  0.7754676 ],
    #        [-0.58739696,  0.17309373,  0.79057155],
    #        [-0.58493056,  0.11730711,  0.80255547],
    #        [-0.58134104,  0.06097503,  0.81137207],
    #        [-0.57664066,  0.0043029 ,  0.81698656],
    #        [-0.57084604, -0.05250241,  0.81937677],
    #        [-0.56397813, -0.10923341,  0.81853328],
    #        [-0.55606209, -0.16568288,  0.81445941],
    #        [-0.5471272 , -0.22164477,  0.80717124],
    #        [-0.53720674, -0.27691506,  0.79669754],
    #        [-0.52633781, -0.33129261,  0.78307963],
    #        [-0.51456126, -0.38458001,  0.76637127],
    #        [-0.50192142, -0.43658447,  0.74663839],
    #        [-0.48846601, -0.48711858,  0.72395887],
    #        [-0.47424587, -0.53600118,  0.69842222],
    #        [-0.40683915, -0.62427028,  0.66690968],
    #        [-0.33576038, -0.70283795,  0.62712342],
    #        [-0.26242347, -0.77071649,  0.58062898],
    #        [-0.18819559, -0.82732742,  0.52925586],
    #        [-0.11432892, -0.87251341,  0.47502552],
    #        [-0.04189881, -0.9065231 ,  0.42007185],
    #        [ 0.02824678, -0.92996663,  0.36655721],
    #        [ 0.09552095, -0.9437416 ,  0.31658734],
    #        [ 0.15960929, -0.94893149,  0.27212848],
    #        [ 0.22045466, -0.94668243,  0.23493003],
    #        [ 0.27821292, -0.93806909,  0.2064557 ],
    #        [ 0.33318133, -0.92396516,  0.18782593],
    #        [ 0.38570646, -0.90493752,  0.17977376],
    #        [ 0.43608353, -0.88118246,  0.18261604],
    #        [ 0.48446269, -0.85251709,  0.19624094],
    #        [ 0.53077742, -0.81842892,  0.22011232],
    #        [ 0.57470612, -0.77817518,  0.25329088],
    #        [ 0.61567118, -0.73091434,  0.29447109],
    #        [ 0.65287258, -0.67584844,  0.34203256],
    #        [ 0.68534826, -0.61235607,  0.3941038 ],
    #        [ 0.65597113, -0.60822873,  0.44694483],
    #        [ 0.62380597, -0.6023622 ,  0.49802197],
    #        [ 0.58897116, -0.59477538,  0.54713365],
    #        [ 0.55159612, -0.58549423,  0.59408604],
    #        [ 0.51182083, -0.57455167,  0.63869384],
    #        [ 0.46979511, -0.56198734,  0.68078101],
    #        [ 0.42567808, -0.54784747,  0.72018145],
    #        [ 0.37963738, -0.53218469,  0.75673966],
    #        [ 0.33184856, -0.51505774,  0.79031137],
    #        [ 0.28249428, -0.49653128,  0.82076407],
    #        [ 0.26273112, -0.45799367,  0.84924328],
    #        [ 0.24123638, -0.41637536,  0.87660514],
    #        [ 0.21809362, -0.37184819,  0.90231042],
    #        [ 0.19341428, -0.32463509,  0.92585257],
    #        [ 0.16733682, -0.27500806,  0.94676763],
    #        [ 0.140025  , -0.22328457,  0.96464346],
    #        [ 0.11166531, -0.16982252,  0.97912776],
    #        [ 0.08246382, -0.11501394,  0.9899351 ],
    #        [ 0.05264234, -0.05927776,  0.99685251],
    #        [ 0.02243422, -0.00305164,  0.99974366]])


 